import numpy as np
import matplotlib.pyplot as plt
import os.path as osp
import os

def read_csvs(fpath):
    """Load every ``.csv`` file in a directory.

    Each file is parsed as comma-separated numbers and its first column is
    dropped, so column 0 of each returned array is the file's second column.

    Args:
        fpath: Directory to scan (non-recursive).

    Returns:
        A ``(names, arrays)`` tuple. ``names`` is the sorted list of ``.csv``
        file names found directly in *fpath*; ``arrays[i]`` is the 2-D array
        loaded from ``names[i]`` with its first column removed.
    """
    # Keep only CSVs (this also skips the plotting script itself and any
    # previously saved figure) so names and arrays stay index-aligned; sort
    # because os.listdir returns entries in arbitrary order.
    names = sorted(f for f in os.listdir(fpath) if f.endswith('.csv'))
    # Join with fpath so loading works when fpath is not the current
    # directory; ndmin=2 keeps single-row files 2-D for the column slice.
    arrays = [
        np.loadtxt(osp.join(fpath, name), delimiter=',', ndmin=2)[:, 1:]
        for name in names
    ]
    return names, arrays
# Read the CSVs once and reuse the result for both the dump and the plot.
names, curves = read_csvs('.')
# Debug dump of everything that is about to be plotted.
print((names, curves))

for name, curve in zip(names, curves):
    # read_csvs already dropped the first CSV column, so column 0 is the
    # x-axis (epoch) and the last column is the y-axis (total loss).
    # splitext removes the ".csv" suffix exactly; str.rstrip('.csv') would
    # strip any trailing '.', 'c', 's', 'v' characters (e.g. "abc.csv" -> "ab").
    plt.plot(curve[:, 0], curve[:, -1], label=osp.splitext(name)[0])
plt.xlabel('epoch')
plt.ylabel('total loss')
plt.legend()
plt.savefig('ablation.png')
plt.show()